#------------------------------------------------------------------------------
# set general parameters
# ------------------------------------------------------------------------------

# path where the data files are stored
data.dir = file.path("build","data")

# EIA API key
# get one at: http://www.eia.gov/opendata/register.cfm
# The key is read from the EIA_API_KEY environment variable when it is set,
# so a personal key does not have to live in source control. The previously
# hard-coded key is kept as a fallback so existing scripts keep working.
# NOTE(review): a key committed to source control should be treated as exposed.
eia.api.key = Sys.getenv("EIA_API_KEY")
if (!nzchar(eia.api.key))
  eia.api.key = "CBE2656F80FDF13628A9BA70327CA042"



# ------------------------------------------------------------------------------
# load libraries
# ------------------------------------------------------------------------------

library(ggplot2)
library(mapproj)
library(maps)
library(httr)
library(RJSONIO)



# ------------------------------------------------------------------------------
# functions for running the model
# ------------------------------------------------------------------------------

# Run the SAGE model by building a GAMS command line and executing it with
# system(). File arguments are placed in fixed sub-directories using
# Windows-style backslash separators: the model under model\, results under
# output\, and benchmark/parameter files under data\.
#
# Arguments:
#   model                 : GAMS model file in model\
#   policy_file           : optional file passed as --policy_file (NA = none)
#   output_file           : csv results file written to output\; the same name
#                           with its final extension replaced by ".gdx" is
#                           passed as --gdx_results_file
#   gdx_baseline_file     : optional baseline gdx in output\ (NA = none)
#   proximal_perturbation : optional --proximal_perturbation value (NA = none)
#   gdx_save              : value passed as --gdx_save
#   benchmark_file        : benchmark file in data\
#   parameter_file        : parameter file in data\
#   options               : extra text appended to the command line as-is
#
# Returns (invisibly) the value returned by system(), i.e. the exit status.
run.model = function(model="sage.gms",
                     policy_file=NA,
                     output_file="results.csv",
                     gdx_baseline_file=NA,
                     proximal_perturbation=NA,
                     gdx_save=1,
                     benchmark_file="default_aggregation.gdx",
                     parameter_file="parameters.gms",
                     options="") {

  # results gdx file: only the final extension is stripped, so names that
  # contain dots (e.g. "results.v2.csv") keep their full stem
  results_gdx = paste0(tools::file_path_sans_ext(output_file),".gdx")

  # basic command to run the model
  command = paste0("gams model\\\\",model,
                   " --output_file=output\\\\",output_file,
                   " --gdx_results_file=output\\\\",results_gdx,
                   " --gdx_save=",gdx_save,
                   " --benchmark_file=data\\\\",benchmark_file,
                   " --parameter_file=data\\\\",parameter_file,
                   " ",options)

  # if a policy file is provided then include that in the command
  if (!is.na(policy_file))
    command = paste0(command," --policy_file=",policy_file)

  # if a baseline file is provided then include that in the command
  if (!is.na(gdx_baseline_file))
    command = paste0(command," --gdx_baseline_file=output\\\\",gdx_baseline_file)

  # if a proximal_perturbation value is provided then include that in the command
  if (!is.na(proximal_perturbation))
    command = paste0(command," --proximal_perturbation=",proximal_perturbation)

  # run the model
  invisible(system(command))

}



# ------------------------------------------------------------------------------
# functions for working with model output
# ------------------------------------------------------------------------------

# Sum a parameter's values over every row recorded for one time period.
#   x : results data frame with parameter, time and value columns
#   p : parameter name to select
#   t : time period to select
# Values are converted with as.numeric() before summing.
agg.value = function(x,p,t) {
  selected = x$parameter==p & x$time==t
  sum(as.numeric(x$value[selected]))
}



# Difference (a minus b) between the national totals of parameter p in
# period t for two result sets, using agg.value() for each total.
agg.diff = function(a,b,p,t) {
  total.a = agg.value(a,p,t)
  total.b = agg.value(b,p,t)
  total.a - total.b
}



# Percent change of parameter p in period t between result set a and the
# reference result set b: agg.diff(a, b) relative to b's total, times 100.
agg.pct = function(a,b,p,t) {
  change = agg.diff(a,b,p,t)
  reference = agg.value(b,p,t)
  change/reference*100
}



# Combine a parameter reported separately for new and extant capital.
#
# Extant-capital values are looked up under the parameter name with "_ex"
# appended (e.g. "K" and "K_ex"). When such rows exist their values are added
# to the new-capital rows. The addition is positional, so it assumes both
# sets of rows are listed in the same order.
#
# Arguments:
#   results   : results data frame with parameter and value columns
#   parameter : base (new capital) parameter name
#   sector    : optional sector to keep (NA = keep all rows)
#
# Returns the rows of `parameter` with the combined values.
combine.new.extant = function(results,parameter,sector=NA) {
  temp = results[results$parameter==parameter,]

  # paste0: the extant name has no separator ("K_ex", not "K _ex")
  extant = results$value[results$parameter==paste0(parameter,"_ex")]

  if (length(extant)>0) {
    # refuse to recycle silently when the two sets of rows do not line up
    if (length(extant)!=nrow(temp))
      stop("found ",nrow(temp)," rows for ",parameter," but ",length(extant),
           " rows for ",paste0(parameter,"_ex"),call.=FALSE)
    temp$value = temp$value+extant
  }

  if (!is.na(sector))
    temp = temp[temp$sector==sector,]

  return(temp)
}



# ------------------------------------------------------------------------------
# output functions
# ------------------------------------------------------------------------------

# printf-style helper: format the arguments with sprintf(), print the
# resulting character vector, and return it invisibly.
printf <- function(...) {
  formatted <- sprintf(...)
  print(formatted)
  invisible(formatted)
}



# Print the change in a parameter's national aggregate between two scenarios
# for one time period, both as an absolute difference and as a percentage.
#
# Arguments:
#   p : parameter name
#   t : time period (printed with no decimals)
#   a : policy-scenario results; defaults to the global `policy`
#   b : baseline results; defaults to the global `bau`
#
# Returns the formatted line invisibly (as printf() does).
compare = function(p,t,a=policy,b=bau) {
  printf('%s Change in %4.0f....: %7.3f (%7.3f%%)',p,t,agg.diff(a,b,p,t),
         agg.pct(a,b,p,t))
}



# Write a variable to a GAMS input file as one assignment statement per row.
#
# Arguments:
#   file          : a file path, or a connection (written to and left open)
#   variable.name : GAMS symbol written on the left-hand side
#   variable      : a single value, or a matrix/data frame whose leading
#                   columns are the set indices and whose last column is the
#                   value
#   append        : when `file` is a path, append instead of overwriting
#   comments      : optional character vector, each written as a "* " line
#   write.zeros   : if FALSE, rows whose value is zero are not written
#   force.index   : currently unused
#
# Example: a data frame row ("ca","ele",1.5) written as "p" gives
#   p("ca","ele") = 1.5;
# A blank line is written after the variable.
write.gams.variable = function(file,variable.name,variable,append=FALSE,
                               comments=NULL,write.zeros=FALSE,
                               force.index=FALSE) {

  # a connection is written to as-is and left open for the caller; a path is
  # opened here and closed on exit, even if writing fails
  if (inherits(file,"connection")) {
    fc = file
  } else {
    fc = file(file,open=if (append) "at" else "wt")
    on.exit(close(fc),add=TRUE)
  }

  # write any comments to the input file
  for (comment in comments)
    writeLines(paste0("* ",comment),con=fc)

  # a single value has no dimensions; anything else must be a matrix or a
  # data frame with the value in its last column
  is.scalar = is.null(dim(variable))
  if (is.scalar) {
    if (length(variable)!=1)
      stop("variable must be a single value, a matrix or a data frame",
           call.=FALSE)
    num.rows = 1
    num.cols = 1
  } else {
    num.rows = nrow(variable)
    num.cols = ncol(variable)
  }

  # write the variable line by line (no lines for a zero-row table)
  for (row in seq_len(num.rows)) {

    # quoted, comma separated index columns: name("i","j")
    line = variable.name
    if (num.cols>1) {
      indices = vapply(seq_len(num.cols-1),
                       function(col) paste0("\"",variable[row,col],"\""),
                       character(1))
      line = paste0(line,"(",paste(indices,collapse=","),")")
    }

    value = if (is.scalar) variable else variable[row,num.cols]

    # zeros are skipped unless requested; an NA value is written out rather
    # than failing the zero test
    if (write.zeros || is.na(value) || value!=0)
      writeLines(paste0(line," = ",value,";"),con=fc)

  }

  # skip a line to make the file more readable
  writeLines("",con=fc)

  invisible(NULL)
}



# ------------------------------------------------------------------------------
# function to get years represented in model
# ------------------------------------------------------------------------------

# Return the model years declared in the time set `t` of a build file.
#
# Reads build/<build.file>, finds the line starting with the time-set
# declaration, and collects the comma separated elements between the first
# pair of "/" delimiters. The list may be written on one line or span several
# lines. Empty entries (e.g. from trailing commas) are dropped.
#
# Returns a numeric vector of years.
get.sage.years = function(build.file="build_aeo_baseline.gms") {

  lines = readLines(file.path("build",build.file))

  # declaration line of the time set
  header = "  t                     time periods"
  start = which(startsWith(lines,header))[1]
  if (is.na(start))
    stop("time period set declaration not found in ",build.file,call.=FALSE)

  # join lines (comma separated, since a line break also separates elements)
  # until both the opening and the closing "/" have been seen
  text = lines[start]
  end = start
  while (nchar(gsub("[^/]","",text))<2) {
    end = end+1
    if (end>length(lines))
      stop("unterminated time period list in ",build.file,call.=FALSE)
    text = paste(text,lines[end],sep=",")
  }

  # keep only the element list between the first two "/"
  elements = sub("^[^/]*/([^/]*)/.*$","\\1",text)

  # split on commas and drop empty entries
  years = trimws(strsplit(elements,",",fixed=TRUE)[[1]])
  return(as.numeric(years[nzchar(years)]))
}



# ------------------------------------------------------------------------------
# functions for working with household aggregation
# ------------------------------------------------------------------------------

# Load the household aggregation from a GAMS mapping file.
#
# Reads build/aggregation_map/<aggregation.file>, finds the h_map(h_a,h)
# declaration, and parses the following lines up to the first line that
# contains ";". A line containing ".(" starts a new aggregate household (the
# text before ".("). Any other line is a member household, of which the first
# five characters (after trimming) are kept. Blank lines, lines starting with
# "*" in column 1 (GAMS comments) and lines starting with ")" are skipped.
#
# Returns a data frame with columns h (aggregate household) and subhousehold,
# or NULL when the mapping lists no members.
load.household.aggregation = function(aggregation.file="default_aggregation.gms") {

  # load the aggregation mapping file
  map.file = readLines(file.path("build","aggregation_map",aggregation.file))

  # the mapping starts on the line after the h_map(h_a,h) declaration
  start = grep("h_map(h_a,h)",map.file,fixed=TRUE)
  if (length(start)==0)
    stop("h_map(h_a,h) not found in ",aggregation.file,call.=FALSE)
  ind = start[1]+1

  households = character(0)
  subhouseholds = character(0)
  household = NULL

  # loop through the file until the end of the mapping
  while (ind<=length(map.file) && !grepl(";",map.file[ind],fixed=TRUE)) {

    line = map.file[ind]
    entry = trimws(line)

    if (grepl(".(",line,fixed=TRUE)) {
      # a new aggregate household: keep the name before ".("
      household = trimws(substr(line,1,regexpr(".(",line,fixed=TRUE)-1))
    } else if (nzchar(entry) && !startsWith(line,"*") && !startsWith(entry,")")) {
      # a member household of the current aggregate household
      if (is.null(household))
        stop("household listed before any aggregate household in ",
             aggregation.file,call.=FALSE)
      households = c(households,household)
      subhouseholds = c(subhouseholds,substr(entry,1,5))
    }

    ind = ind+1
  }

  if (ind>length(map.file))
    stop("h_map(h_a,h) in ",aggregation.file," is not terminated by ';'",
         call.=FALSE)

  if (length(households)==0)
    return(NULL)

  return(data.frame(h=households,subhousehold=subhouseholds,
                    stringsAsFactors=FALSE))
}



# Names of the aggregate households in the model: the distinct values of the
# h column of load.household.aggregation(aggregation.file).
get.sage.households = function(aggregation.file="default_aggregation.gms") {
  household.map = load.household.aggregation(aggregation.file)
  unique(household.map$h)
}



# ------------------------------------------------------------------------------
# functions for working with sectoral aggregation
# ------------------------------------------------------------------------------

# Load the sectoral aggregation from a GAMS mapping file.
#
# Reads build/aggregation_map/<aggregation.file>, finds the s_map(s_a,s)
# declaration, and parses the following lines up to the first line that
# contains ";". A line containing ".(" starts a new aggregate sector (the
# text before ".("). Any other line is a member sector, of which the first
# three characters (after trimming) are kept. Blank lines, lines starting with
# "*" in column 1 (GAMS comments) and lines starting with ")" are skipped.
#
# Returns a data frame with columns s (aggregate sector) and subsector, or
# NULL when the mapping lists no members.
load.sectoral.aggregation = function(aggregation.file="default_aggregation.gms") {

  # load the aggregation mapping file
  map.file = readLines(file.path("build","aggregation_map",aggregation.file))

  # the mapping starts on the line after the s_map(s_a,s) declaration
  start = grep("s_map(s_a,s)",map.file,fixed=TRUE)
  if (length(start)==0)
    stop("s_map(s_a,s) not found in ",aggregation.file,call.=FALSE)
  ind = start[1]+1

  sectors = character(0)
  subsectors = character(0)
  sector = NULL

  # loop through the file until the end of the mapping
  while (ind<=length(map.file) && !grepl(";",map.file[ind],fixed=TRUE)) {

    line = map.file[ind]
    entry = trimws(line)

    if (grepl(".(",line,fixed=TRUE)) {
      # a new aggregate sector: keep the name before ".("
      sector = trimws(substr(line,1,regexpr(".(",line,fixed=TRUE)-1))
    } else if (nzchar(entry) && !startsWith(line,"*") && !startsWith(entry,")")) {
      # a member sector of the current aggregate sector
      if (is.null(sector))
        stop("sector listed before any aggregate sector in ",
             aggregation.file,call.=FALSE)
      sectors = c(sectors,sector)
      subsectors = c(subsectors,substr(entry,1,3))
    }

    ind = ind+1
  }

  if (ind>length(map.file))
    stop("s_map(s_a,s) in ",aggregation.file," is not terminated by ';'",
         call.=FALSE)

  if (length(sectors)==0)
    return(NULL)

  return(data.frame(s=sectors,subsector=subsectors,stringsAsFactors=FALSE))
}



# Names of the aggregate sectors in the model: the distinct values of the
# s column of load.sectoral.aggregation(aggregation.file).
get.sage.sectors = function(aggregation.file="default_aggregation.gms") {
  sector.map = load.sectoral.aggregation(aggregation.file)
  unique(sector.map$s)
}



# ------------------------------------------------------------------------------
# functions for working with regional aggregation
# ------------------------------------------------------------------------------

# Load the regional aggregation from a GAMS mapping file.
#
# Reads build/aggregation_map/<aggregation.file>, finds the r_map(r_a,r)
# declaration, and parses the following lines up to the first line that
# contains ";". A line containing ".(" starts a new aggregate region (the
# text before ".("). For any other line, the first two characters (after
# trimming) are kept as a state when, upper-cased, they are a US state
# abbreviation or "DC". All other lines are ignored.
#
# Returns a data frame with columns r (aggregate region) and state, or NULL
# when the mapping lists no states.
load.regional.aggregation = function(aggregation.file="default_aggregation.gms") {

  # load the aggregation mapping file
  map.file = readLines(file.path("build","aggregation_map",aggregation.file))

  # the regional mapping starts on the line after the r_map(r_a,r) declaration
  start = grep("r_map(r_a,r)",map.file,fixed=TRUE)
  if (length(start)==0)
    stop("r_map(r_a,r) not found in ",aggregation.file,call.=FALSE)
  ind = start[1]+1

  regions = character(0)
  states = character(0)
  region = NULL

  # loop through the file until the end of the regional mapping
  while (ind<=length(map.file) && !grepl(";",map.file[ind],fixed=TRUE)) {

    line = map.file[ind]

    if (grepl(".(",line,fixed=TRUE)) {
      # a new aggregate region: keep the name before ".("
      region = trimws(substr(line,1,regexpr(".(",line,fixed=TRUE)-1))
    } else {
      # keep the line only if it holds a state abbreviation
      state = substr(trimws(line),1,2)
      if (toupper(state) %in% c(state.abb,"DC")) {
        if (is.null(region))
          stop("state listed before any aggregate region in ",
               aggregation.file,call.=FALSE)
        regions = c(regions,region)
        states = c(states,state)
      }
    }

    ind = ind+1
  }

  if (ind>length(map.file))
    stop("r_map(r_a,r) in ",aggregation.file," is not terminated by ';'",
         call.=FALSE)

  if (length(regions)==0)
    return(NULL)

  return(data.frame(r=regions,state=states,stringsAsFactors=FALSE))
}



# Names of the aggregate regions in the model: the distinct values of the
# r column of load.regional.aggregation(aggregation.file).
get.sage.regions = function(aggregation.file="default_aggregation.gms") {
  region.map = load.regional.aggregation(aggregation.file)
  unique(region.map$r)
}



# Aggregate state-level data to the model regions.
#
# Arguments:
#   data             : data frame with columns state and value
#   map              : state-to-region mapping with columns r and state; when
#                      NA (default) or NULL it is loaded with
#                      load.regional.aggregation(aggregation.file)
#   aggregation.file : aggregation map used when `map` is not supplied
#
# Every state in the mapping is kept (merge with all.x=TRUE). States in
# `data` without a region are dropped, and missing values are ignored in the
# regional sums. Returns a data frame with columns r and value.
aggregate.state.data = function(data,map=NA,
                                aggregation.file="default_aggregation.gms") {

  # only a scalar NA (or NULL) means "load the default mapping"; calling
  # is.na() on a supplied data frame would give a multi-element condition
  if (is.null(map) || (is.atomic(map) && length(map)==1 && is.na(map)))
    map = load.regional.aggregation(aggregation.file=aggregation.file)

  # add the data to the regional mapping
  data = merge(map,data,by="state",all.x=TRUE)

  # aggregate the data across the regions
  data = aggregate(data$value,by=list(data$r),sum,na.rm=TRUE)

  # rename the columns
  names(data) = c("r","value")

  return(data)

}



# Plot a state-level choropleth of regional values.
#
# Each state (plus DC) is colored by the value of the model region it belongs
# to, according to the regional aggregation in `aggregation.file`.
#
# Arguments:
#   map.data        : data frame with columns region and value
#   aggregation.file: aggregation map under build/aggregation_map
#   plot.title      : plot title
#   legend.title    : legend title
#   num.colors      : number of RColorBrewer colors to use
#   palette         : RColorBrewer palette name
#   border.color    : state border color; NA gives light grey
#   discrete        : TRUE for a manual (discrete) fill scale, FALSE for a
#                     continuous gradient
#
# The plot is printed and the ggplot object is returned invisibly.
state.choropleth = function(map.data,aggregation.file="default_aggregation.gms",
                            plot.title="",legend.title="",num.colors=9,
                            palette="Blues",border.color=NA,discrete=FALSE) {

  # load required libraries
  library(maps)
  library(fiftystater)
  library(ggplot2)
  library(RColorBrewer)

  # load the map polygons into this function's environment rather than the
  # global environment
  data("fifty_states",package="fiftystater",envir=environment())

  # load the state to region mapping for the model
  region.map = load.regional.aggregation(aggregation.file)

  # load the state name to abbreviation mapping
  states.map = data.frame(abbreviation=tolower(state.abb),
                          name=tolower(state.name))

  # add the district of columbia
  states.map = rbind(states.map,data.frame(abbreviation="dc",
                                           name="district of columbia"))

  # look up each state's region, then that region's value; states or regions
  # without a match get NA instead of producing a malformed column
  state.region = region.map$r[match(states.map$abbreviation,region.map$state)]
  states.map$value = map.data$value[match(state.region,map.data$region)]

  # color for the state borders
  if (is.na(border.color))
    border.color = rgb(222,222,222,maxColorValue=255)

  # create the ggplot object
  p = ggplot(states.map,aes(map_id=name))+
    geom_map(aes(fill=value),map=fifty_states,color=border.color)+
    expand_limits(x=fifty_states$long,y=fifty_states$lat)+
    coord_map()+
    scale_x_continuous(breaks=NULL)+
    scale_y_continuous(breaks=NULL)+
    labs(x="",y="")+
    ggtitle(plot.title)

  if (discrete) {
    p = p+scale_fill_manual(name=legend.title,values=brewer.pal(num.colors,palette),
                            labels=map.data$region[as.numeric(map.data$value)],drop=TRUE)+
      guides(fill=guide_legend(override.aes=list(linetype=1,size=1,color="white")))
  } else {
    p = p+scale_fill_gradientn(name=legend.title,colors=brewer.pal(num.colors,palette))
  }

  p = p+theme(text              = element_text(family="Helvetica"),
              axis.line         = element_blank(),
              panel.grid.major  = element_blank(),
              panel.grid.minor  = element_blank(),
              panel.border      = element_blank(),
              panel.background  = element_blank(),
              axis.text.x       = element_blank(),
              axis.title.x      = element_blank(),
              axis.ticks.x      = element_blank(),
              axis.title.y      = element_blank(),
              axis.text.y       = element_blank(),
              axis.ticks.y      = element_blank(),
              plot.title        = element_text(size=rel(0.9),margin=margin(0,0,10,0)),
              legend.direction  = "vertical",
              legend.position   = c(0.9,0.3),
              legend.key        = element_rect(size=0,color="white"),
              legend.key.height = unit(.2,"inch"),
              legend.key.width  = unit(.35,"inch")
             )

  print(p)

  # invisible so the plot is not printed a second time at the console
  invisible(p)

}



# Map the states colored by the SAGE region each one belongs to.
#   aggregation.file : aggregation map under build/aggregation_map
# Regions are labelled with the strings "n", ..., "1" and drawn with a
# discrete "Set3" palette and black borders. Returns what state.choropleth()
# returns.
plot.sage.regions = function(aggregation.file="default_aggregation.gms") {

  region.names = get.sage.regions(aggregation.file=aggregation.file)
  n.regions = length(region.names)

  # descending codes, stored as strings so the fill scale is discrete
  region.codes = as.character(n.regions:1)
  map.data = data.frame(region=region.names,value=region.codes)

  state.choropleth(map.data,
                   aggregation.file=aggregation.file,
                   plot.title="",
                   legend.title="",
                   num.colors=n.regions,
                   palette="Set3",
                   border.color="#000000",
                   discrete=TRUE)
}



# ------------------------------------------------------------------------------
# functions for working with EIA data
# ------------------------------------------------------------------------------

# Download one data series from the EIA API (series endpoint).
#
# Arguments:
#   series.id : EIA series identifier
#   key       : EIA API key (defaults to the global eia.api.key)
#   year      : optional vector of years; when given only those rows are kept
#
# Returns a data frame with numeric columns year and value. Null data points
# in the response become NA. Stops with an error on an HTTP failure or when
# the response contains no series.
get.eia = function(series.id,key=eia.api.key,year=NULL) {

  # url for the EIA API (https so the API key is not sent in clear text)
  eia.url = "https://api.eia.gov/series/"

  # let httr build and url-encode the query string
  response = GET(eia.url,query=list(api_key=key,series_id=series.id))
  stop_for_status(response)

  # parse the response and take the data of the first series
  parsed = fromJSON(content(response,as="text",encoding="UTF-8"))
  series = if (is.list(parsed)) parsed[["series"]] else NULL
  if (length(series)==0)
    stop("EIA API returned no series for ",series.id,call.=FALSE)
  json.data = series[[1]][["data"]]

  # set any null values to NA
  json.data = lapply(json.data,function(x) lapply(x,function(y) if (is.null(y)) NA else y))

  # reformat data as a two column matrix
  data = t(matrix(as.numeric(unlist(json.data)),nrow=2))

  # add all the data to a data frame
  eia.data = data.frame(year=data[,1],value=data[,2])

  # if specific years are requested only return those
  if (!is.null(year))
    eia.data = eia.data[eia.data$year %in% year,]

  return(eia.data)

}



# ------------------------------------------------------------------------------
# functions for working with census data
# ------------------------------------------------------------------------------

# Download state-level trade values for one HS commodity from the Census
# international trade API (timeseries/intltrade/<direction>/statehs).
#
# Arguments:
#   hs.commodity : HS commodity code used as the I_/E_COMMODITY filter
#   direction    : "exports" (value variable ALL_VAL_YR) or "imports"
#                  (value variable GEN_VAL_YR)
#   year         : year whose "<year>-12" time period is requested; defaults
#                  to a global `year`, which the function previously read
#                  implicitly
#
# Returns a data frame with columns state (lower-cased) and value (billion
# dollars). Stops with an error on an HTTP failure or unexpected columns.
get.trade.data = function(hs.commodity,direction="exports",
                          year=get("year",envir=globalenv())) {

  if (!direction %in% c("exports","imports"))
    stop("direction must be \"exports\" or \"imports\"",call.=FALSE)

  # census api url
  url = paste0("https://api.census.gov/data/timeseries/intltrade/",
               direction,"/statehs")

  # commodity parameter I_COMMODITY for imports and E_COMMODITY for exports
  com.var = paste0(toupper(substr(direction,1,1)),"_COMMODITY")

  # value parameter based on exports or imports
  variable = if (direction=="exports") "ALL_VAL_YR" else "GEN_VAL_YR"

  # create the API request
  req = paste0(url,"?get=STATE,",com.var,",",variable,"&time=",year,"-12&",
               com.var,"=",hs.commodity)

  # get the data
  response = GET(req)
  stop_for_status(response)
  json.data = fromJSON(content(response,as="text",encoding="UTF-8"))

  # one character row per record (null -> NA), so a missing field cannot
  # shift the remaining values into the wrong columns
  rows = do.call(rbind,lapply(json.data,function(record)
    vapply(record,function(v) if (is.null(v)) NA_character_ else as.character(v),
           character(1),USE.NAMES=FALSE)))

  # locate the state and value columns by name in the header row
  header = rows[1,]
  state.col = match("STATE",header)
  value.col = match(variable,header)
  if (is.na(state.col) || is.na(value.col))
    stop("unexpected columns in census response: ",
         paste(header,collapse=","),call.=FALSE)

  # drop the header row and the row after it
  # NOTE(review): why the second row is skipped is not visible here — confirm
  records = rows[-c(1,2),,drop=FALSE]

  # state abbreviations in lower case, values converted to billion dollars
  data = data.frame(state=tolower(records[,state.col]),
                    value=as.numeric(records[,value.col])*1e-9,
                    stringsAsFactors=FALSE)

  return(data)

}
